"""
Phase 1: Data Assembly — Innovation/R&D Capital Allocation Panel
================================================================
Merges full_panel.csv with WDI R&D and innovation indicators to
test whether demographic structure predicts innovation effort
and cross-border R&D-seeking capital allocation.

Output: innovation/data/processed/innovation_panel.csv
Tables: summary_statistics.md
"""

import sys
from pathlib import Path
import time

import numpy as np
import pandas as pd
import warnings
# Silence every warning category for the whole process. NOTE(review): this
# also hides pandas/numpy runtime and deprecation warnings raised below.
warnings.filterwarnings('ignore')

# Filesystem layout. Everything hangs off the innovation project folder; the
# sibling "multilateral" project lives next to it under the same root.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/innovation")
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR.joinpath("multilateral")
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
RAW_DIR = PROJECT_DIR.joinpath("data", "raw")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# Create the data/output folders at import time so later writes cannot fail
# on a missing directory.
for d in (PROCESSED_DIR, RAW_DIR, TABLES_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Put the multilateral project's src/ folder first on the import path.
sys.path.insert(0, str(MULTILATERAL_DIR.joinpath("src")))

# ISO3 codes used for the `oecd` membership dummy (38 entries), kept in the
# original listing order.
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()

# WDI innovation indicators: World Bank indicator code -> panel column name.
#
# Each code must appear exactly once. A dict literal silently keeps the LAST
# value for a repeated key, so listing 'GB.XPD.RSDV.GD.ZS' a second time as
# 'rd_expenditure_gdp' used to rename the R&D series and the `rd_gdp` column
# was never produced. Downloads cached while that duplicate existed carry the
# series as `rd_expenditure_gdp`; delete data/raw/wdi_innovation.csv to
# refresh them.
WDI_INDICATORS = {
    'GB.XPD.RSDV.GD.ZS': 'rd_gdp',                  # R&D expenditure (% GDP)
    'IP.PAT.RESD': 'patent_residents',              # Patent applications, residents
    'IP.PAT.NRES': 'patent_nonresidents',           # Patent applications, non-residents
    'IP.TMK.RESD': 'trademark_residents',           # Trademark applications, residents
    'TX.VAL.TECH.MF.ZS': 'hightech_exports_share',  # High-tech exports (% manufactured exports)
    'TX.VAL.TECH.CD': 'hightech_exports_usd',       # High-tech exports (current US$)
    'IP.JRN.ARTC.SC': 'scientific_articles',        # Scientific/technical journal articles
    'BX.KLT.DINV.WD.GD.ZS': 'fdi_inflows_gdp',      # FDI net inflows (% GDP)
    'BM.KLT.DINV.WD.GD.ZS': 'fdi_outflows_gdp',     # FDI net outflows (% GDP)
}


def download_wdi_indicator(indicator, name, countries, max_retries=3):
    """Download one WDI indicator for a list of countries via the World Bank API.

    Countries are requested in chunks of 15 codes per HTTP call, for the
    years 1990-2024. Each chunk is attempted up to ``max_retries`` times; a
    chunk that still fails is reported on stdout and skipped, so the result
    may be incomplete (best effort, never raises for a failed request).

    Args:
        indicator: World Bank indicator code, e.g. ``'IP.PAT.RESD'``.
        name: Column name under which the values are stored.
        countries: Sequence of country codes (strings) to request.
        max_retries: Maximum number of attempts per chunk.

    Returns:
        pandas.DataFrame with columns ``iso3``, ``year`` and ``name``: one
        row per non-null observation, de-duplicated on (iso3, year). An
        empty frame with those three columns if nothing was retrieved.
    """
    import json
    import urllib.request

    print(f"  Downloading WDI {indicator} ({name}) ...")
    all_rows = []
    chunk_size = 15

    for i in range(0, len(countries), chunk_size):
        chunk = countries[i:i + chunk_size]
        country_str = ';'.join(chunk)
        url = (f"https://api.worldbank.org/v2/country/{country_str}/"
               f"indicator/{indicator}?format=json&per_page=10000&date=1990:2024")

        for attempt in range(max_retries):
            try:
                req = urllib.request.Request(url)
                req.add_header('User-Agent', 'Mozilla/5.0')
                with urllib.request.urlopen(req, timeout=30) as resp:
                    data = json.loads(resp.read().decode())

                # Collect this attempt's rows separately so that a payload
                # which fails halfway through parsing leaves nothing behind
                # for the retry to duplicate.
                chunk_rows = []
                # Observations are read from the second element of the
                # response; shorter or empty payloads contribute no rows.
                if len(data) >= 2 and data[1]:
                    for obs in data[1]:
                        if obs.get('value') is not None:
                            chunk_rows.append({
                                'iso3': obs['countryiso3code'],
                                'year': int(obs['date']),
                                name: float(obs['value']),
                            })
            except Exception as e:
                # Deliberately broad: any network or payload problem only
                # costs this chunk, never the whole run.
                if attempt == max_retries - 1:
                    print(f"    Failed chunk {i//chunk_size}: {e}")
                else:
                    # Back off only when another attempt will follow.
                    time.sleep(1.5)
            else:
                all_rows.extend(chunk_rows)
                break

        # Short pause between chunks to stay polite to the API.
        time.sleep(0.5)

    if all_rows:
        df = pd.DataFrame(all_rows)
        df = df.drop_duplicates(subset=['iso3', 'year'])
        print(f"    {name}: {df['iso3'].nunique()} countries, {len(df)} obs")
        return df
    print(f"    {name}: no data retrieved")
    return pd.DataFrame(columns=['iso3', 'year', name])


def main():
    """Assemble the innovation/R&D panel and write it to disk.

    Steps:
      1. Load ``full_panel.csv`` from the multilateral project (years <= 2024).
      2. Load the cached WDI innovation download, or download every indicator
         in ``WDI_INDICATORS`` and cache the merged result, then left-merge
         the columns that are new to the panel on (iso3, year).
      3. Derive patent totals/intensity, log R&D, log patents and the
         non-resident patent share.
      4. Build demographic interactions, 5-row lags and first differences.
      5. Add the OECD dummy and within-year income terciles.
      6. Restrict to 1990-2024 and print coverage.
      7. Write ``summary_statistics.md``.
      8. Write ``innovation_panel.csv``.

    Returns:
        pandas.DataFrame: the assembled panel that was written to disk.
    """
    print("=" * 70)
    print("PHASE 1: Data Assembly — Innovation/R&D Panel")
    print("=" * 70)

    # ── [1] Load full_panel.csv ──
    print("\n[1] Loading full_panel.csv ...")
    full_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    df = pd.read_csv(full_path)
    df = df[df['year'] <= 2024].copy()
    # Drop missing codes first: NaN cannot be sorted against strings nor
    # joined into the request URL.
    countries = sorted(df['iso3'].dropna().unique())
    print(f"  Full panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    # ── [2] Download WDI innovation indicators ──
    print("\n[2] Downloading WDI innovation indicators ...")

    raw_cache = RAW_DIR / "wdi_innovation.csv"
    if raw_cache.exists():
        print("  Loading cached WDI innovation data ...")
        inn_df = pd.read_csv(raw_cache)
    else:
        # Keep only the first indicator for each column name, so two
        # indicators can never collide on the same column in the merge below.
        unique_indicators = {}
        seen_names = set()
        for indicator, name in WDI_INDICATORS.items():
            if name not in seen_names:
                seen_names.add(name)
                unique_indicators[indicator] = name

        inn_dfs = []
        for indicator, name in unique_indicators.items():
            idf = download_wdi_indicator(indicator, name, countries)
            if len(idf) > 0:
                inn_dfs.append(idf)

        if inn_dfs:
            inn_df = inn_dfs[0]
            for idf in inn_dfs[1:]:
                inn_df = inn_df.merge(idf, on=['iso3', 'year'], how='outer')
            # NOTE(review): the cache is written even if some chunks failed
            # to download; delete the file to force a fresh attempt.
            inn_df.to_csv(raw_cache, index=False)
        else:
            inn_df = pd.DataFrame(columns=['iso3', 'year'])

    # Merge only the columns the panel does not already have.
    if len(inn_df) > 0:
        existing = set(df.columns)
        new_cols = [c for c in inn_df.columns if c not in existing and c not in ['iso3', 'year']]
        if new_cols:
            merge_df = inn_df[['iso3', 'year'] + new_cols].drop_duplicates(
                subset=['iso3', 'year'])
            df = df.merge(merge_df, on=['iso3', 'year'], how='left')
            for col in new_cols:
                n = df[col].notna().sum()
                nc = df.loc[df[col].notna(), 'iso3'].nunique()
                print(f"  {col}: {n:,} obs, {nc} countries")

    # WDI indicator GB.XPD.RSDV.GD.ZS is known under two names in this
    # project (rd_gdp / rd_expenditure_gdp). If only the alias is present —
    # e.g. in a cached download — expose it as rd_gdp so that log_rd and the
    # R&D coverage/summary below are populated.
    if 'rd_gdp' not in df.columns and 'rd_expenditure_gdp' in df.columns:
        df['rd_gdp'] = df['rd_expenditure_gdp']
        print("  rd_gdp: aliased from rd_expenditure_gdp")

    # ── [3] Construct derived variables ──
    print("\n[3] Constructing derived variables ...")

    # Total patent applications: residents + non-residents, treating a single
    # missing component as 0 and leaving the total missing when both are.
    if 'patent_residents' in df.columns:
        residents = df['patent_residents']
        if 'patent_nonresidents' in df.columns:
            nonresidents = df['patent_nonresidents']
        else:
            # Index-aligned placeholder; a bare pd.Series(0) would align on
            # label 0 only and blank out every other row.
            nonresidents = pd.Series(np.nan, index=df.index)
        df['patent_total'] = residents.add(nonresidents, fill_value=0)

    # Patent intensity (per million population)
    # NOTE(review): assumes population_weo is expressed in millions — confirm.
    if 'patent_total' in df.columns and 'population_weo' in df.columns:
        df['patents_per_million'] = df['patent_total'] / df['population_weo'].clip(lower=0.01)
        print(f"  patents_per_million: {df['patents_per_million'].notna().sum():,} obs")

    # Log R&D; values are floored at 0.001 so zeros do not yield -inf.
    # Missing inputs stay missing through clip() and log().
    if 'rd_gdp' in df.columns:
        df['log_rd'] = np.log(df['rd_gdp'].clip(lower=0.001))

    # Log patents; counts are floored at 1 so zero totals map to log(1) = 0.
    if 'patent_total' in df.columns:
        df['log_patents'] = np.log(df['patent_total'].clip(lower=1))

    # Non-resident patent share (attracting foreign innovation); the
    # denominator is floored at 1 to avoid division by zero.
    if 'patent_nonresidents' in df.columns and 'patent_total' in df.columns:
        df['nonres_patent_share'] = df['patent_nonresidents'] / df['patent_total'].clip(lower=1)

    # ── [4] Interaction terms ──
    print("\n[4] Interaction terms ...")
    demo_vars = ['Z_1', 'Z_2', 'Z_3']

    if 'kaopen' in df.columns:
        for zv in demo_vars:
            if zv in df.columns:
                df[f'{zv}_x_kaopen'] = df[zv] * df['kaopen']

    # Lags/differences are positional within country, so rows must be in
    # year order. NOTE(review): shift(5) is a 5-row lag; it equals a 5-year
    # lag only if each country's rows are consecutive years (not checked).
    df = df.sort_values(['iso3', 'year'])
    for zv in demo_vars:
        if zv in df.columns:
            by_country = df.groupby('iso3')[zv]
            df[f'{zv}_lag5'] = by_country.shift(5)
            df[f'd_{zv}'] = by_country.diff()

    # ── [5] Regime variables ──
    df['oecd'] = df['iso3'].isin(OECD_38).astype(int)

    if 'gdp_pc_ppp' in df.columns:
        def safe_qcut(x):
            # Within-year terciles; years where qcut cannot form three
            # labelled bins get NaN instead of aborting the run.
            try:
                return pd.qcut(x, 3, labels=['low', 'mid', 'high'], duplicates='drop')
            except (ValueError, IndexError):
                return pd.Series(np.nan, index=x.index)
        df['income_group'] = df.groupby('year')['gdp_pc_ppp'].transform(safe_qcut)

    # ── [6] Restrict and summarize ──
    # The year filter comes after the lags so that pre-1990 rows can still
    # feed the lagged values of the first retained years.
    print("\n[6] Panel summary ...")
    df = df[(df['year'] >= 1990) & (df['year'] <= 2024)].copy()

    n_total = len(df)
    n_countries = df['iso3'].nunique()
    if 'rd_gdp' in df.columns:
        has_rd = df['rd_gdp'].notna()
        n_rd = has_rd.sum()
        nc_rd = df.loc[has_rd, 'iso3'].nunique()
    else:
        n_rd = 0
        nc_rd = 0
    print(f"  Total panel: {n_countries} countries, {n_total:,} obs")
    print(f"  R&D/GDP: {nc_rd} countries, {n_rd:,} obs")

    # ── [7] Summary statistics ──
    key_vars = ['rd_gdp', 'log_rd', 'patent_total', 'patents_per_million',
                'nonres_patent_share', 'hightech_exports_share',
                'scientific_articles', 'fdi_inflows_gdp', 'fdi_outflows_gdp',
                'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'rgdp_growth', 'kaopen', 'log_gdp_pc']
    key_vars = [v for v in key_vars if v in df.columns]

    md = ["# Summary Statistics — Innovation/R&D Panel\n"]
    md.append("| Variable | N | Mean | SD | Min | Max |")
    md.append("|---|---|---|---|---|---|")
    for v in key_vars:
        s = df[v].dropna()
        if len(s) > 0:
            md.append(f"| {v} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} "
                      f"| {s.min():.3f} | {s.max():.3f} |")
    out = TABLES_DIR / "summary_statistics.md"
    # Explicit UTF-8: the heading contains a non-ASCII dash, and the platform
    # default encoding is not guaranteed to represent it.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")

    # ── [8] Save ──
    print("\n[8] Saving innovation_panel.csv ...")
    df.to_csv(PROCESSED_DIR / "innovation_panel.csv", index=False)
    print(f"  Saved: {PROCESSED_DIR / 'innovation_panel.csv'}")
    print(f"  Shape: {df.shape[0]:,} obs x {df.shape[1]} columns")

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return df


if __name__ == "__main__":
    df = main()
